{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_TzQEAkRJDNS"
   },
   "source": [
    "# VolPy pipeline for processing voltage imaging data \n",
    "The processing pipeline includes motion correction, memory mapping, segmentation, denoising and source extraction. The demo shows how to construct the params, MotionCorrect and VOLPY objects and call the relevant functions. \n",
    "Dataset courtesy of Karel Svoboda Lab (Janelia Research Campus)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "cellView": "both",
    "id": "NzHIeSOVHRm3",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from base64 import b64encode\n",
    "import cv2\n",
    "import glob\n",
    "import h5py\n",
    "import imageio\n",
    "from IPython import get_ipython\n",
    "from IPython.display import HTML, display, clear_output\n",
    "import logging\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import torch\n",
    "\n",
    "try:\n",
    "    cv2.setNumThreads(0)\n",
    "except:\n",
    "    pass\n",
    "\n",
    "import caiman as cm\n",
    "from caiman.motion_correction import MotionCorrect\n",
    "from caiman.utils.utils import download_demo, download_model\n",
    "from caiman.source_extraction.volpy import utils\n",
    "from caiman.source_extraction.volpy.volparams import volparams\n",
    "from caiman.source_extraction.volpy.volpy import VOLPY\n",
    "from caiman.summary_images import local_correlations_movie_offline, mean_image\n",
    "from caiman.paths import caiman_datadir\n",
    "\n",
    "logfile = None # Replace with a path if you want to log to a file\n",
    "logger = logging.getLogger('caiman')\n",
    "# Set to logging.INFO if you want much output, potentially much more output\n",
    "logger.setLevel(logging.ERROR)\n",
    "logfmt = logging.Formatter('%(relativeCreated)12d [%(filename)s:%(funcName)20s():%(lineno)s] [%(process)d] %(message)s')\n",
    "if logfile is not None:\n",
    "    handler = logging.FileHandler(logfile)\n",
    "else:\n",
    "    handler = logging.StreamHandler()\n",
    "handler.setFormatter(logfmt)\n",
    "logger.addHandler(handler)\n",
    "# Set play_movies to False if you want to disable play of movies, e.g. for remote-hosted Jupyter environments\n",
    "play_movies = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load demo movie and ROIs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EW54aHS_HRnE",
    "outputId": "6fd13521-983f-4711-ac76-fc3a601075f4",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# File path to movie file (will download if not present)\n",
    "fnames = download_demo('demo_voltage_imaging.hdf5') \n",
    "# File path to ROIs file (will download if not present)\n",
    "path_ROIs = download_demo('demo_voltage_imaging_ROIs.hdf5')  \n",
    "file_dir = os.path.split(fnames)[0]\n",
    "\n",
    "# Setup some parameters for data and motion correction dataset parameters\n",
    "fr = 400                                        # sample rate of the movie\n",
    "ROIs = None                                     # Region of interests\n",
    "index = None                                    # index of neurons\n",
    "weights = None                                  # reuse spatial weights by \n",
    "                                                # opts.change_params(params_dict={'weights':vpy.estimates['weights']})\n",
    "# Motion correction parameters\n",
    "pw_rigid = False                                # flag for pw-rigid motion correction\n",
    "gSig_filt = (3, 3)                              # size of filter, in general gSig (see below),\n",
    "                                                # change this one if algorithm does not work\n",
    "max_shifts = (5, 5)                             # maximum allowed rigid shift\n",
    "strides = (48, 48)                              # start a new patch for pw-rigid motion correction every x pixels\n",
    "overlaps = (24, 24)                             # overlap between patches (size of patch strides+overlaps)\n",
    "max_deviation_rigid = 3                         # maximum deviation allowed for patch with respect to rigid shifts\n",
    "border_nan = 'copy'\n",
    "\n",
    "opts_dict = {\n",
    "    'fnames': fnames,\n",
    "    'fr': fr,\n",
    "    'index': index,\n",
    "    'ROIs': ROIs,\n",
    "    'weights': weights,\n",
    "    'pw_rigid': pw_rigid,\n",
    "    'max_shifts': max_shifts,\n",
    "    'gSig_filt': gSig_filt,\n",
    "    'strides': strides,\n",
    "    'overlaps': overlaps,\n",
    "    'max_deviation_rigid': max_deviation_rigid,\n",
    "    'border_nan': border_nan\n",
    "}\n",
    "\n",
    "opts = volparams(params_dict=opts_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true,
     "base_uri": "https://localhost:8080/"
    },
    "id": "_HqUNdCyPBBW",
    "outputId": "ea7cfec7-8fe9-43cd-8ea0-ab32b6befbd6",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display the movie\n",
    "m_orig = cm.load(fnames)\n",
    "ds_ratio = 0.2\n",
    "moviehandle = m_orig.resize(1, 1, ds_ratio)\n",
    "min_, max_ = np.min(moviehandle), np.max(moviehandle)\n",
    "moviehandle = cm.movie((moviehandle-min_)/(max_-min_)*255,dtype='uint8')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if play_movies:\n",
    "    moviehandle.play(fr=40, q_max=99.5, magnification=4)  # press q to exit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "F6wKKAaeK1V_",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Start a cluster for parallel processing\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Motion Correction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LQBBg_xb13rp",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Create a motion correction object with the specified parameters\n",
    "mc = MotionCorrect(fnames, dview=dview, **opts.get_group('motion'))\n",
    "mc.motion_correct(save_movie=True)\n",
    "dview.terminate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "MQ8BDHESHRnh",
    "outputId": "f66a77ea-4c00-415f-a002-41630308fd1d"
   },
   "outputs": [],
   "source": [
    "# Motion correction compared to original movie\n",
    "m_orig = cm.load(fnames)\n",
    "m_rig = cm.load(mc.mmap_file)\n",
    "m_orig.fr = 400\n",
    "m_rig.fr = 400\n",
    "ds_ratio = 0.2\n",
    "moviehandle = cm.concatenate([m_orig.resize(1, 1, ds_ratio) - mc.min_mov * mc.nonneg_movie,\n",
    "                              m_rig.resize(1, 1, ds_ratio)], axis=2)\n",
    "min_, max_ = np.min(moviehandle), np.max(moviehandle)\n",
    "moviehandle = cm.movie((moviehandle-min_)/(max_-min_)*255,dtype='uint8')\n",
    "if play_movies:\n",
    "    moviehandle.play(fr=40, q_max=99.5, magnification=4)  # press q to exit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "SdeR6igh2cAd",
    "outputId": "73e0a0c3-a76b-4b9b-8c91-679ee80e4475"
   },
   "outputs": [],
   "source": [
    "# Movie subtracted from the baseline\n",
    "m_rig2 = m_rig.computeDFF(secsWindow=1)[0][:1000]\n",
    "moviehandle1 = -m_rig2\n",
    "min_, max_ = np.min(moviehandle1), np.max(moviehandle1)\n",
    "moviehandle1 = cm.movie((moviehandle1-min_)/(max_-min_)*255,dtype='uint8')\n",
    "if play_movies:\n",
    "    moviehandle1.play(fr=40, q_max=99.5, magnification=4)  # press q to exit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Memory Mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "tpJDa20bHRnp",
    "outputId": "fd1ef494-77f0-4f4c-98bf-9db08ed81c61"
   },
   "outputs": [],
   "source": [
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)\n",
    "border_to_0 = 0 if mc.border_nan == 'copy' else mc.border_to_0\n",
    "fname_new = cm.save_memmap_join(mc.mmap_file, base_name='memmap_',\n",
    "                           add_to_mov=border_to_0, dview=dview, n_chunks=10)\n",
    "dview.terminate()\n",
    "\n",
    "# Change fnames to the new motion corrected one\n",
    "opts.change_params(params_dict={'fnames': fname_new})    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "TLSPvv5Nt-u-"
   },
   "outputs": [],
   "source": [
    "if 'dview' in locals():\n",
    "    cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Segmentation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "yBTRlgBunb4Q",
    "outputId": "ebd990ad-3041-4e5d-ee64-51995a23150b"
   },
   "outputs": [],
   "source": [
    "# Create mean and correlation images\n",
    "img = mean_image(mc.mmap_file[0], window = 1000, dview=dview)\n",
    "img = (img-np.mean(img))/np.std(img)\n",
    "\n",
    "gaussian_blur = False        # Use gaussian blur when there is too much noise in the video\n",
    "Cn = local_correlations_movie_offline(mc.mmap_file[0], fr=fr, window=fr*4, \n",
    "                                      stride=fr*4, winSize_baseline=fr, \n",
    "                                      remove_baseline=True, gaussian_blur=gaussian_blur,\n",
    "                                      dview=dview).max(axis=0)\n",
    "img_corr = (Cn-np.mean(Cn))/np.std(Cn)\n",
    "summary_images = np.stack([img, img, img_corr], axis=0).astype(np.float32)\n",
    "# Save summary images which could be further used in the VolPy GUI\n",
    "cm.movie(summary_images).save(fnames[:-5] + '_summary_images.tif')\n",
    "\n",
    "fig, axs = plt.subplots(1, 2)\n",
    "axs[0].imshow(summary_images[0]); axs[1].imshow(summary_images[2])\n",
    "axs[0].set_title('mean image'); axs[1].set_title('corr image')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "yUDya-p2HRnu"
   },
   "outputs": [],
   "source": [
    "use_maskrcnn = True  # set to True to predict the ROIs using the mask R-CNN\n",
    "if not use_maskrcnn:                 # use manual annotations\n",
    "    with h5py.File(path_ROIs, 'r') as fl:\n",
    "        ROIs = fl['mov'][()]  # load ROIs\n",
    "else:\n",
    "    weights_path = download_model('mask_rcnn')    \n",
    "    ROIs = utils.mrcnn_inference_pytorch(img=summary_images.transpose([1, 2, 0]), \n",
    "                                        size_range=[5, 22],\n",
    "                                        weights_path=weights_path, \n",
    "                                        display_result=True) # size parameter decides size range of masks to be selected\n",
    "    cm.movie(ROIs).save(fnames[:-5] + '_mrcnn_ROIs.hdf5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axs = plt.subplots(1, 2)\n",
    "axs[0].imshow(summary_images[0]); axs[1].imshow(ROIs.sum(0))\n",
    "axs[0].set_title('mean image'); axs[1].set_title('masks')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "6z92CC4631AA"
   },
   "outputs": [],
   "source": [
    "# Restart cluster to clean up memory\n",
    "cm.stop_server(dview=dview)\n",
    "c, dview, n_processes = cm.cluster.setup_cluster(\n",
    "    backend='multiprocessing', n_processes=None, single_thread=False, maxtasksperchild=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Trace denoising and spike extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "wCNTy-LZHRn6"
   },
   "outputs": [],
   "source": [
    "# Parameters for trace denoising and spike extraction\n",
    "ROIs = ROIs                                   # region of interests\n",
    "index = list(range(len(ROIs)))                # index of neurons\n",
    "weights = None                                # if None, use ROIs for initialization; to reuse weights check reuse weights block \n",
    "\n",
    "template_size = 0.02                          # half size of the window length for spike templates, default is 20 ms \n",
    "context_size = 35                             # number of pixels surrounding the ROI to censor from the background PCA\n",
    "visualize_ROI = False                         # whether to visualize the region of interest inside the context region\n",
    "flip_signal = True                            # Important!! Flip signal or not, True for Voltron indicator, False for others\n",
    "hp_freq_pb = 1 / 3                            # parameter for high-pass filter to remove photobleaching\n",
    "clip = 100                                    # maximum number of spikes to form spike template\n",
    "threshold_method = 'adaptive_threshold'       # adaptive_threshold or simple \n",
    "min_spikes= 10                                # minimal spikes to be found\n",
    "pnorm = 0.5                                   # a variable deciding the amount of spikes chosen for adaptive threshold method\n",
    "threshold = 3                                 # threshold for finding spikes only used in simple threshold method, Increase the threshold to find less spikes\n",
    "do_plot = False                               # plot detail of spikes, template for the last iteration\n",
    "ridge_bg= 0.01                                # ridge regression regularizer strength for background removement, larger value specifies stronger regularization \n",
    "sub_freq = 20                                 # frequency for subthreshold extraction\n",
    "weight_update = 'ridge'                       # ridge or NMF for weight update\n",
    "n_iter = 2                                    # number of iterations alternating between estimating spike times and spatial filters\n",
    "\n",
    "opts_dict={'fnames': fname_new,\n",
    "            'ROIs': ROIs,\n",
    "            'index': index,\n",
    "            'weights': weights,\n",
    "            'template_size': template_size, \n",
    "            'context_size': context_size,\n",
    "            'visualize_ROI': visualize_ROI, \n",
    "            'flip_signal': flip_signal,\n",
    "            'hp_freq_pb': hp_freq_pb,\n",
    "            'clip': clip,\n",
    "            'threshold_method': threshold_method,\n",
    "            'min_spikes':min_spikes,\n",
    "            'pnorm': pnorm, \n",
    "            'threshold': threshold,\n",
    "            'do_plot':do_plot,\n",
    "            'ridge_bg':ridge_bg,\n",
    "            'sub_freq': sub_freq,\n",
    "            'weight_update': weight_update,\n",
    "            'n_iter': n_iter}\n",
    "\n",
    "opts.change_params(params_dict=opts_dict);    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "IAm_iznuHRoA"
   },
   "outputs": [],
   "source": [
    "vpy = VOLPY(n_processes=n_processes, dview=dview, params=opts)\n",
    "vpy.fit(n_processes=n_processes, dview=dview)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "L9Otq9-hJMMj"
   },
   "outputs": [],
   "source": [
    "# Visualize spatial footprints and traces\n",
    "print(np.where(vpy.estimates['locality'])[0])    # neurons that pass locality test\n",
    "idx = np.where(vpy.estimates['locality'] > 0)[0]\n",
    "utils.view_components(vpy.estimates, img_corr, idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reconstructed movie\n",
    "mv_all = utils.reconstructed_movie(vpy.estimates.copy(), fnames=mc.mmap_file,\n",
    "                                           idx=idx, scope=(0,1000), flip_signal=flip_signal)\n",
    "if play_movies:\n",
    "    mv_all.play(fr=40, magnification=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vpy.estimates['ROIs'] = ROIs\n",
    "save_name = f'volpy_{os.path.split(fnames)[1][:-5]}_{threshold_method}'\n",
    "np.save(os.path.join(file_dir, save_name), vpy.estimates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%% Reuse spatial weights extracted from previous video\n",
    "# set weights = reuse_weights in opts_dict dictionary\n",
    "if False:\n",
    "    estimates = np.load(os.path.join(file_dir, save_name+'.npy'), allow_pickle=True).item()\n",
    "    reuse_weights = []\n",
    "    for idx in range(ROIs.shape[0]):\n",
    "        coord = estimates['context_coord'][idx]\n",
    "        w = estimates['weights'][idx][coord[0][0]:coord[1][0]+1, coord[0][1]:coord[1][1]+1] \n",
    "        plt.figure(); plt.imshow(w);plt.colorbar(); plt.show()\n",
    "        reuse_weights.append(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true
    },
    "id": "jv9zCNpbHRoO"
   },
   "outputs": [],
   "source": [
    "# Stop cluster and clean up log files\n",
    "cm.stop_server(dview=dview)\n",
    "log_files = glob.glob('*_LOG_*')\n",
    "for log_file in log_files:\n",
    "    os.remove(log_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "caiman_pytorch2",
   "language": "python",
   "name": "caiman_pytorch2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
